package spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/**
 * @Author Jeremy Zheng
 * @Date 2021/3/18 16:27
 * @Version 1.0
 * Demo of a custom (user-defined) aggregate function class.
 */
object SparkSQL03_UDAF_Demo1 {

  /**
   * Entry point: reads `dataSet/test.json`, exposes it as the temp view `user`,
   * registers [[MyAvgUDAF]] under the SQL name `ageAvg` and prints the result of
   * `select ageAvg(age) from user`.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Configure the Spark runtime environment.
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("UDF")
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // Release the session even when reading the file or running the query fails.
    try {
      val df: DataFrame = spark.read.json("dataSet/test.json")
      df.createOrReplaceTempView("user")

      spark.udf.register("ageAvg", new MyAvgUDAF())

      spark.sql("select ageAvg(age) from user").show()
    } finally {
      // Release resources.
      spark.close()
    }
  }

  /**
   * Custom aggregate function that computes the average age.
   *
   * The aggregation buffer holds a running `total` and `count`. Null input
   * values are ignored, and the result is null when no non-null value was
   * aggregated. The result is a `Long`, so the average is truncated by
   * integer division.
   */
  class MyAvgUDAF extends UserDefinedAggregateFunction {

    // Structure of the input data: a single Long column named "age".
    override def inputSchema: StructType = {
      StructType(
        Array(
          StructField("age", LongType)
        )
      )
    }

    // Structure of the buffer data: running sum ("total") and number of values ("count").
    override def bufferSchema: StructType = {
      StructType(
        Array(
          StructField("total", LongType),
          StructField("count", LongType)
        )
      )
    }

    // Type of the function's result.
    override def dataType: DataType = LongType

    // Determinism of the function: the same input always yields the same output.
    override def deterministic: Boolean = true

    // Buffer initialization: both the total and the count start at zero.
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0L)
      buffer.update(1, 0L)
    }

    // Update the buffer with one input value.
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      // Row.getLong throws on a null value, so rows with a null/missing age are
      // skipped; they contribute to neither the total nor the count.
      if (!input.isNullAt(0)) {
        buffer.update(0, buffer.getLong(0) + input.getLong(0))
        buffer.update(1, buffer.getLong(1) + 1)
      }
    }

    // Merge two buffers (needed because the computation is distributed):
    // totals and counts are added component-wise into buffer1.
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }

    // Compute the average from the final buffer.
    override def evaluate(buffer: Row): Any = {
      val count = buffer.getLong(1)
      // Guard against division by zero when nothing was aggregated (empty input
      // or only null ages); a null result signals "no average available".
      if (count == 0L) null
      else buffer.getLong(0) / count // integer division: the result is truncated
    }

  }
}


